import torch.nn as nn
import torch
import yaml
from model.MAVEN.simulator import MAVEN
from model.MGN.simulator import MGN
from model.HCMT.simulator import HCMT_MAIN
from model.FIGNet.simulator import FIGNet
from model.GT.simulator import GraphTransformer_MAIN
from model.HOOD.simulator import HOOD_MAIN


class Simulator(nn.Module):
    """Facade that builds one of several simulator backends from a YAML config.

    The backend is chosen by the config's ``model_type`` key (see
    ``_build_model`` for the supported values). ``forward`` and the checkpoint
    helpers delegate straight to the selected backend, stored as ``self.model``.
    """

    def __init__(self, config_dir, device=torch.device('cpu')) -> None:
        """Load the YAML config at *config_dir* and instantiate its backend.

        Args:
            config_dir: Path of the YAML config file to read.
            device: Torch device; stored on ``self.device`` and forwarded to
                the backend constructor.

        Raises:
            ValueError: If the file does not contain a YAML mapping, or its
                ``model_type`` is missing or unsupported.
        """
        super().__init__()
        self.device = device
        # YAML is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(config_dir, 'r', encoding='utf-8') as file:
            config = yaml.safe_load(file)
        # safe_load returns None for an empty file and a list for a top-level
        # sequence; neither can carry a "model_type" key.
        if not isinstance(config, dict):
            raise ValueError(
                f"config file {config_dir!r} must contain a YAML mapping, "
                f"got {type(config).__name__}"
            )
        self.model = self._build_model(config, device)

    @staticmethod
    def _build_model(config, device):
        """Return the backend selected by ``config["model_type"]``.

        Supported values: "MGN", "HCMT", "FIGNet", "MAVEN", "GT", "HOOD".
        The whole *config* mapping and *device* are passed to the backend
        constructor.

        Raises:
            ValueError: If ``model_type`` is missing or not a supported value.
        """
        model_type = config.get("model_type")
        # Zero-argument factories: a backend class is only looked up (and
        # constructed) when it is actually the one selected.
        factories = {
            "MGN": lambda: MGN(config, device),
            "HCMT": lambda: HCMT_MAIN(config, device),
            "FIGNet": lambda: FIGNet(config, device),
            "MAVEN": lambda: MAVEN(config, device),
            "GT": lambda: GraphTransformer_MAIN(config, device),
            "HOOD": lambda: HOOD_MAIN(config, device),
        }
        # Check the type first: an unhashable YAML value (e.g. a list) would
        # otherwise raise TypeError on the dict membership test.
        if not isinstance(model_type, str) or model_type not in factories:
            raise ValueError(
                f"ERROR MODEL TYPE! Unsupported model_type {model_type!r}; "
                f"expected one of: {', '.join(factories)}"
            )
        return factories[model_type]()

    def forward(self, data_list, noise_flag=False):
        """Run the wrapped backend as ``self.model(data_list, noise_flag)``.

        Both arguments are passed through unchanged. Their expected structure,
        and what ``noise_flag`` toggles, are defined by the selected backend
        (presumably input-noise injection during training — confirm there).
        """
        return self.model(data_list, noise_flag)

    def load_checkpoint(self, ckpdir=None):
        """Delegate to the backend's ``load_checkpoint(ckpdir)``.

        How *ckpdir* (including ``None``) is interpreted is up to the backend.
        """
        self.model.load_checkpoint(ckpdir)

    def save_checkpoint(self, savedir=None):
        """Delegate to the backend's ``save_checkpoint(savedir)``.

        How *savedir* (including ``None``) is interpreted is up to the backend.
        """
        self.model.save_checkpoint(savedir)